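"""Helper functions shared by other scripts.

Covers recoding categories, assigning majority raster categories to polygons,
and plotting cross tabulations as heatmaps.
"""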

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
from rasterio.mask import mask
from scipy.stats import mode
from textwrap import wrap
import seaborn as sns
# Apply seaborn's default theme at import time. NOTE: this is a module-level
# side effect -- importing this module restyles matplotlib figures created
# afterwards by the importing script.
sns.set()

def recode(s, file):
    """Recodes categories using a lookup table stored in a CSV file.

    Args:
        s (pandas.Series): The values to recode.
        file (str): File path (or any other source ``pandas.read_csv``
            accepts) to a csv with two columns: "from" and "to".

    Returns:
        pandas.Series: ``s`` with every value replaced by its "to" entry.

    Raises:
        KeyError: If a value of ``s`` does not appear in the "from" column.
    """
    table = pd.read_csv(file)
    # Build the lookup column-wise. DataFrame.iterrows yields each row as a
    # single Series, which upcasts mixed dtypes (e.g. a float "from" column
    # turns integer "to" codes into floats); zipping the columns keeps each
    # column's own dtype.
    recode_dict = dict(zip(table['from'], table['to']))
    return s.apply(lambda x: recode_dict[x])


def vote_by_majority(geometry, raster):
    """Aggregates categories in the raster to a polygon,
    based on the majority voting rule.

    Args:
        geometry (shapely.geometry.polygon.Polygon): The polygon of interest.
        raster (rasterio.io.DatasetReader): The background raster.
            The raster should consist of numpy.uint8 numbers indicating
            categories, where 0 refers to background.

    Returns:
        numpy.uint8: The most frequent non-zero category covered by the
            geometry (the smallest such category on ties), or 0 (background)
            if the geometry covers no non-zero pixels.
    """
    # clip (and crop) the raster to the geometry
    # NOTE(review): this relies on pixels outside the geometry being filled
    # with 0; rasterio fills them with the dataset's nodata value (or 0 if
    # unset) -- confirm for rasters whose nodata is non-zero.
    out_array, _ = mask(dataset=raster, shapes=[geometry], crop=True)
    # flatten all bands/rows, then drop background (0) pixels
    values = out_array.ravel()
    values = values[values != 0]
    if values.size == 0:
        # the raster does not overlap with the geometry:
        # return 0 (background)
        return 0
    # Majority vote. np.unique returns the categories sorted ascending and
    # argmax picks the first maximum, so ties resolve to the smallest
    # category -- the same rule scipy.stats.mode uses. This replaces
    # `mode(...).mode[0]`, which raises IndexError on SciPy >= 1.11 because
    # `mode` no longer keeps a length-1 result axis by default.
    categories, counts = np.unique(values, return_counts=True)
    return categories[np.argmax(counts)]


def raster_to_shapefile(df, raster, col_name):
    """Labels each geometry with the majority raster category beneath it.

    Empty or missing geometries are discarded first, as are geometries whose
    majority category is 0 (background, i.e. no overlap with the raster).
    A count of each kind of discarded row is printed.

    Args:
        df (geopandas.GeoDataFrame): The data frame with the geometries of
            interest.
        raster (rasterio.io.DatasetReader): The background raster.
        col_name (str): The name of the newly created column to store the
            values extracted from the raster.

    Returns:
        geopandas.GeoDataFrame: The surviving rows, reprojected to EPSG:4326,
            with the extra ``col_name`` column.
    """
    geoms = df.geometry
    is_blank = pd.Series(geoms.isnull()) | geoms.is_empty
    print('Dropping {} empty geometries'.format(is_blank.sum()))

    # overlaying only makes sense in the raster's own coordinate system
    result = df[~is_blank].to_crs(raster.crs)
    result[col_name] = result.geometry.apply(
        lambda geom: vote_by_majority(geom, raster))

    # a majority of 0 means the geometry only covered background
    is_background = result[col_name] == 0
    print('Dropping {} non-overlapping geometries'.format(is_background.sum()))

    # reproject to EPSG:4326 to prevent file exporting errors
    return result[~is_background].to_crs(epsg=4326)


def convert(input_series, input_type, metadata_value, metadata_category):
    """Converts categories to values or values to categories.

    Args:
        input_series (pandas.Series): The categories or values to be converted.
        input_type (str): Either 'cat' (convert categories to values) or
            'val' (convert values to categories).
        metadata_value, metadata_category (iterable): Iterables of all
            the corresponding values and categories, paired by position.

    Returns:
        pandas.Series: The converted values or categories.

    Raises:
        NotImplementedError: If ``input_type`` is neither 'cat' nor 'val'.
        KeyError: If an element of ``input_series`` has no counterpart in
            the metadata.
    """
    # pair up metadata into a lookup in the requested direction
    if input_type == 'cat':
        lookup = {cat: val for val, cat in zip(metadata_value, metadata_category)}
    elif input_type == 'val':
        lookup = dict(zip(metadata_value, metadata_category))
    else:
        raise NotImplementedError(
            f"input_type must be 'cat' or 'val', got {input_type!r}")
    # convert cat to val or val to cat
    return input_series.apply(lambda x: lookup[x])


def plot_crosstab(input_df, x_key, y_key, x_title, y_title, file_name,
                  drop_threshold=0, drop_strs=None, x_order=None, y_order=None,
                  **kwargs):
    """Plots a cross tabulation of two variables as an annotated heatmap.

    Args:
        input_df (pandas.DataFrame): The data frame that stores all variables.
            It is copied, never modified.
        x_key, y_key (str): Column names of variables of interest.
        x_title, y_title (str): Axis titles of variables of interest.
        file_name (str): Name of the file saved.
        drop_threshold (int): Categories with at most this many instances are
            grouped into 'Other' instead of shown individually.
        drop_strs (iterable of str): Additional categories to group into
            'Other'.
        x_order, y_order (list of str): Order of the x/y axis categories. When
            given, it must list every category left after grouping
            (including 'Other' if anything was grouped).
        **kwargs: Extra keyword arguments for ``seaborn.heatmap``. They
            override the defaults ``annot=True``, ``fmt='g'`` and the
            10-level 'YlGnBu' colormap.

    Raises:
        ValueError: If ``x_order``/``y_order`` omits a category present in
            the data.
    """
    drop_strs = [] if drop_strs is None else list(drop_strs)
    df = input_df.copy()
    for key, key_order in zip([x_key, y_key], (x_order, y_order)):
        # recode categories with <= drop_threshold instances (plus any listed
        # in drop_strs) to 'Other'
        counts = df[key].value_counts()
        drop = counts[counts <= drop_threshold].index.tolist() + drop_strs
        df[key] = df[key].replace({cat: 'Other' for cat in drop})
        # sort the categories
        if key_order is not None:
            cats = df[key].unique()
            df[key] = pd.Categorical(df[key], key_order)
            # Values not listed in key_order become NaN and would silently
            # vanish from the crosstab. Raise explicitly: an assert is skipped
            # under `python -O`, and str() keeps non-string categories
            # (numbers, NaN) from breaking the message itself.
            if df[key].isna().any():
                missing = [cat for cat in cats if cat not in key_order]
                raise ValueError(
                    f'{key}: categories missing from the given order: '
                    + ', '.join(map(str, missing)))
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was deprecated
    # in matplotlib 3.7 and removed in 3.9
    cmap = plt.get_cmap('YlGnBu', 10)
    fig, ax = plt.subplots(figsize=(18, 10))
    crosstab = pd.crosstab(df[y_key], df[x_key])
    # aesthetic arguments passed through kwargs take precedence over the
    # defaults instead of colliding with them (which raised a TypeError)
    heatmap_kwargs = {'annot': True, 'fmt': 'g', 'cmap': cmap, **kwargs}
    sns.heatmap(crosstab, ax=ax, **heatmap_kwargs)
    ax.set_xlabel(x_title)
    ax.set_ylabel(y_title)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='left')
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')
    fig.tight_layout()  # this prevents clipping of the axis labels
    # NOTE(review): the figure is left open (never plt.close'd), so callers
    # can still display it; repeated calls accumulate open figures.
    fig.savefig(file_name)
